{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f05fa329",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent of the current working directory at the front of sys.path so the\n",
    "# `src` imports below can be resolved.\n",
    "# NOTE(review): presumes the kernel's working directory is the notebook's own folder,\n",
    "# one level below the project root - confirm before running from elsewhere.\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.models import ModelXtoCResNet\n",
    "from src.preprocessing.RIVAL10 import preprocessing_rival10\n",
    "from src.utils import *\n",
    "from src.training import run_epoch_x_to_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f239f5e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 26384 unique images.\n",
      "Found 18 unique concepts.\n",
      "Generated one-hot training matrix with shape: (21098, 10)\n",
      "Found 21098 images.\n",
      "Processing in 330 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|█████████████████████| 330/330 [00:32<00:00, 10.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 21098 images.\n",
      "Found 5286 images.\n",
      "Processing in 83 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|███████████████████████| 83/83 [00:11<00:00,  7.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 5286 images.\n",
      "Dataset initialized with 21098 pre-sorted items.\n",
      "Dataset initialized with 5286 pre-sorted items.\n"
     ]
    }
   ],
   "source": [
    "# Run the RIVAL10 preprocessing step and unpack its four results.\n",
    "concept_labels, train_loader, val_loader, test_loader = preprocessing_rival10(\n",
    "    training=True,\n",
    "    class_concepts=True,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c80140bb",
   "metadata": {},
   "source": [
    "# Training Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ae643c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import find_class_imbalance\n",
    "from src.config import RIVAL10_CONFIG as config_dict\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3d7846bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset sizes read from the RIVAL10 config dictionary.\n",
    "# N_TRIMMED_CONCEPTS is passed as `n_concepts` to the model and to run_epoch_x_to_c.\n",
    "N_TRIMMED_CONCEPTS = config_dict['N_TRIMMED_CONCEPTS']\n",
    "# NOTE(review): N_CLASSES is not referenced by any other cell visible in this notebook.\n",
    "N_CLASSES = config_dict['N_CLASSES']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaf19306",
   "metadata": {},
   "source": [
    "**Find device to run model on (CPU or GPU).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5316e340",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available()\n",
    "                    else \"mps\" if torch.backends.mps.is_available()\n",
    "                    else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b5b0a8",
   "metadata": {},
   "source": [
    "**Instantiate the model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "82fc53a2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Instantiated (X -> C)\n"
     ]
    }
   ],
   "source": [
    "# Build the X -> C network and move it onto the selected device in one expression.\n",
    "model = ModelXtoCResNet(\n",
    "    pretrained=True,\n",
    "    freeze=True,\n",
    "    n_concepts=N_TRIMMED_CONCEPTS,\n",
    ").to(device)\n",
    "\n",
    "print(\"Model Instantiated (X -> C)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4731e539",
   "metadata": {},
   "source": [
    "### Loss\n",
    "We use a weighted loss.\n",
    "\n",
    "`BCEWithLogitsLoss()` performs 2 steps:\n",
    "1. $\\sigma(x)$\n",
    "    - Applies the sigmoid function to the logits to get probabilities.\n",
    "2. $\\text{BCE}(\\sigma(x), y) = -\\left[ y \\cdot \\log(\\sigma(x)) + (1-y) \\cdot \\log(1-\\sigma(x)) \\right]$\n",
    "    - Computes the binary cross-entropy between the output probabilities ($\\sigma(x)$) and the ground truths ($y$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "be2ad09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_weighted_loss = True # Set to False for simple unweighted loss\n",
    "\n",
    "# One BCE-with-logits criterion per concept, collected in a plain Python list.\n",
    "if not use_weighted_loss:\n",
    "    attr_criterion = [nn.BCEWithLogitsLoss() for _ in range(N_TRIMMED_CONCEPTS)]\n",
    "else:\n",
    "    concept_weights = find_class_imbalance(concept_labels)\n",
    "    attr_criterion = []\n",
    "    for ratio in concept_weights:\n",
    "        # NOTE(review): `weight` rescales the whole loss term for this concept, whereas\n",
    "        # `pos_weight` re-weights positive targets only - confirm which is intended.\n",
    "        ratio_tensor = torch.tensor([ratio], device=device, dtype=torch.float)\n",
    "        attr_criterion.append(nn.BCEWithLogitsLoss(weight=ratio_tensor))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "936c945c",
   "metadata": {},
   "source": [
    "### Optimiser\n",
    "Use the same settings as the CBM (Concept Bottleneck Models) repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b50f189a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizer and Scheduler Ready\n"
     ]
    }
   ],
   "source": [
    "lr = 0.01\n",
    "weight_decay = 0.00004 # same as lambda in L2-regularisation\n",
    "\n",
    "# Hand SGD only the parameters that still require gradients.\n",
    "trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9, weight_decay=weight_decay)\n",
    "\n",
    "# StepLR multiplies the LR by gamma once every `scheduler_step` epochs.\n",
    "# NOTE(review): with the 50-epoch run configured further down, a step of 1000 never fires.\n",
    "scheduler_step = 1000\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=scheduler_step)\n",
    "\n",
    "print(\"Optimizer and Scheduler Ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e648c53",
   "metadata": {},
   "source": [
    "### Training and Validation Loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5a03d3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of passes over the training loader made by the training loop.\n",
    "epochs = 50\n",
    "# NOTE(review): log_interval is not referenced by any other cell visible in this notebook.\n",
    "log_interval = 50\n",
    "\n",
    "# Starting value for tracking the best validation accuracy.\n",
    "# NOTE(review): only referenced by the commented-out validation code in the training loop.\n",
    "best_val_acc = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743f1447",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for `epochs` epochs. After each epoch, evaluate on the validation split (when one\n",
    "# exists) and checkpoint the model whenever validation accuracy improves.\n",
    "for epoch in range(epochs):\n",
    "    print(f\"--- Epoch {epoch+1}/{epochs} ---\")\n",
    "\n",
    "    # Train\n",
    "    train_loss, train_acc = run_epoch_x_to_c(\n",
    "        model, train_loader, attr_criterion, optimizer,\n",
    "        is_training=True, use_aux=True,\n",
    "        n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True,\n",
    "    )\n",
    "    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')\n",
    "\n",
    "    # Validate on the validation loader, never the test loader, so that model selection\n",
    "    # does not leak test data. The block is skipped when val_loader is None or empty.\n",
    "    if val_loader:\n",
    "        with torch.no_grad():\n",
    "            val_loss, val_acc = run_epoch_x_to_c(\n",
    "                model, val_loader, attr_criterion, optimizer,\n",
    "                is_training=False,\n",
    "                n_concepts=N_TRIMMED_CONCEPTS, device=device, verbose=True,\n",
    "            )\n",
    "        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')\n",
    "\n",
    "        # Save best model based on validation accuracy\n",
    "        if val_acc > best_val_acc:\n",
    "            print(f\"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...\")\n",
    "            best_val_acc = val_acc\n",
    "            torch.save(model, 'x_to_c_best_model.pth')\n",
    "            print(\"Model saved to x_to_c_best_model.pth\")\n",
    "\n",
    "    # Scheduler step\n",
    "    scheduler.step()\n",
    "    print(f\"Current LR: {optimizer.param_groups[0]['lr']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28684261-09da-40ca-bf12-a13f99b29d04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the test split without tracking gradients.\n",
    "# NOTE(review): this scores `model` as it currently sits in memory; no checkpoint is\n",
    "# loaded here, so confirm that matches the \"Best Model\" label in the summary line.\n",
    "if test_loader:\n",
    "    with torch.no_grad():\n",
    "\n",
    "        test_loss, test_acc, outputs = run_epoch_x_to_c(\n",
    "            model, test_loader, attr_criterion, optimizer=None, n_concepts=N_TRIMMED_CONCEPTS, device=device,\n",
    "            return_outputs='sigmoid', verbose=True\n",
    "        )\n",
    "\n",
    "    # Report inside the guard: test_loss / test_acc only exist when evaluation ran.\n",
    "    print(f'Best Model Summary   | Loss: {test_loss:.4f} | Acc: {test_acc:.3f}')\n",
    "else:\n",
    "    print(\"No test loader available; skipping test evaluation.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
